import torch, numpy
import torchvision
import torch.nn as nn
import torch.nn.functional as F

class KKT_loss(nn.Module):
    """Weighted sum of a "lagrange" term and a "duality" term on ``guider`` outputs.

    ``forward`` returns ``lagrange_coe * lagrange_loss + duality_coe * duality_loss``
    together with both individual terms. Which concrete term is used is selected
    by the flags read from ``cfg.Generator`` in ``__init__``.
    """

    def __init__(self, cfg, *args, **kwargs):
        """Read hyper-parameters from ``cfg.Generator`` and the keyword arguments.

        Args:
            cfg: configuration object; only the ``cfg.Generator`` attributes
                read below are used.
            *args: ignored.
            **kwargs: must contain ``guider`` (a module with at least one
                parameter), ``Lambda`` (an object exposing
                ``Lambda_prime_dict(alpha)``) and ``num_samples``.
        """
        super().__init__()
        guider = kwargs['guider']
        self.device = next(guider.parameters()).device
        self.Lambda = kwargs['Lambda']
        self.num_samples = kwargs['num_samples']

        generator_cfg = cfg.Generator
        self.duality_threshold = generator_cfg.duality_threshold
        self.lagrange_coe = generator_cfg.lagrange_coe
        self.duality_coe = generator_cfg.duality_coe
        self.double_label = generator_cfg.double_label
        self.use_stationarity_loss = generator_cfg.use_stationarity_loss
        self.use_kl_loss = generator_cfg.use_kl_loss
        self.use_second_order_stationarity = generator_cfg.use_second_order_stationarity
        self.first_class_probability = generator_cfg.first_class_probability
        self.first_class_probability_use_alpha = generator_cfg.first_class_probability_use_alpha

        # Fix the parameter order once so gradients can be matched to names later.
        self.parameter_name_list = [name for name, _ in guider.named_parameters()]

    def _first_class_probability(self, alpha):
        """Return the probability assigned to the first (target) class.

        When ``first_class_probability_use_alpha`` is set this is
        ``sigmoid(exp(alpha))``; it is written as ``1 / (1 + exp(-exp(alpha)))``
        because the literal ``e / (1 + e)`` with ``e = exp(exp(alpha))``
        overflows to ``inf / inf = nan`` for moderately large ``alpha``.
        Otherwise the configured constant is returned.
        """
        if self.first_class_probability_use_alpha:
            return 1.0 / (1.0 + numpy.exp(-numpy.exp(alpha)))
        return self.first_class_probability

    def forward(self, x, y, l, pred, alpha, guider):
        """Compute the combined loss.

        Args:
            x: unused by this method.
            y: labels. Indexed as ``y[:, 0]`` / ``y[:, 1]`` when
                ``double_label`` is set, otherwise used directly as class
                indices (assumes one index per row of ``pred`` -- TODO confirm).
            l: per-sample, per-class multipliers indexed as ``l[row, class]``;
                may be ``None`` only when ``use_stationarity_loss`` is set.
            pred: ``guider`` outputs, indexed as ``pred[row, class]``.
            alpha: scalar forwarded to the individual loss terms.
            guider: module whose parameters the lagrange term differentiates.

        Returns:
            Tuple ``(loss, lagrange_loss, duality_loss)``.

        Raises:
            ValueError: if ``l`` is ``None`` while the multiplier-based
                lagrange term is selected.
        """
        rows = torch.arange(pred.shape[0], device=pred.device)

        if self.double_label:
            target_label = y[:, 0]
            rival_label = y[:, 1]
            rival_pred = pred[rows, rival_label]
        else:
            target_label = y
            top_values, top_indices = torch.topk(pred, k=2, dim=1)
            rival_pred, rival_label = top_values[:, 1], top_indices[:, 1]
        # Pick each sample's own logit; ``pred[:, labels]`` would instead build
        # an N-by-N table mixing every sample with every other sample's label.
        target_pred = pred[rows, target_label]

        if self.use_stationarity_loss:
            lagrange_loss = self.stationarity_loss(pred, target_label, guider)
        else:
            if l is None:
                raise ValueError(
                    "l must not be None when use_stationarity_loss is disabled"
                )
            second_l = l[rows, rival_label]
            lagrange_loss = self.lagrange_loss(
                target_pred, rival_pred, second_l, alpha, guider
            )

        if self.double_label:
            # The double-label branch always uses the KL term.
            first_p = self._first_class_probability(alpha)
            duality_loss = self.KL_loss(pred, y, (first_p, 1 - first_p, alpha))
        elif self.use_kl_loss:
            duality_loss = self.KL_loss(pred, y, self._first_class_probability(alpha))
        else:
            duality_loss = self.duality_loss(target_pred, rival_pred, alpha)

        loss = self.lagrange_coe * lagrange_loss + self.duality_coe * duality_loss
        return loss, lagrange_loss, duality_loss

    def duality_loss(self, y_pred, second_pred, alpha):
        """Soft two-sided penalty on the margin ``y_pred - second_pred``.

        Adds ``softplus(margin - alpha - duality_threshold)`` and
        ``softplus(alpha - margin)``. The result is not reduced; it has the
        broadcast shape of the inputs.
        """
        loss_upper = F.softplus(y_pred - second_pred - alpha - self.duality_threshold)
        loss_lower = F.softplus(-y_pred + second_pred + alpha)
        return loss_upper + loss_lower

    def KL_loss(self, Phi, y, first_class_probability):
        """KL divergence between a target distribution and ``softmax(Phi, dim=1)``.

        Args:
            Phi: logits; softmax is taken along dimension 1.
            y: class indices. With ``double_label`` it is indexed as
                ``y[:, 0]`` and ``y[:, 1]``.
            first_class_probability: with ``double_label`` a sequence whose
                first two items are the probabilities placed on ``y[:, 0]``
                and ``y[:, 1]``; otherwise a single probability placed on
                ``y``, with the remainder spread evenly over the other classes.

        Returns:
            ``F.kl_div(..., reduction='batchmean')`` of the two distributions.
        """
        # log_softmax avoids the -inf that softmax(...).log() yields on underflow.
        log_h = F.log_softmax(Phi, dim=1)

        if self.double_label:
            # Target is zero everywhere except at the two labelled classes.
            p = torch.zeros(y.shape, dtype=log_h.dtype, device=log_h.device)
            p[:, 0] += first_class_probability[0]
            p[:, 1] += first_class_probability[1]
            P = torch.zeros_like(log_h).scatter(dim=1, index=y, src=p)
        else:
            P = torch.ones_like(log_h)
            P = P * (1 - first_class_probability) / (log_h.shape[1] - 1)
            P[torch.arange(P.shape[0], device=P.device), y] = first_class_probability

        return F.kl_div(log_h, P, reduction='batchmean')

    def stationarity_loss(self, pred, y, guider):
        """Norm of the gradient of the cross-entropy w.r.t. ``guider`` parameters.

        Returns the sum of squared gradient entries when
        ``use_second_order_stationarity`` is set, otherwise the sum of their
        absolute values. The gradient is built with ``create_graph=True`` so
        the result can itself be differentiated.
        """
        loss = F.cross_entropy(pred, y)
        grads = torch.autograd.grad(
            loss, list(guider.parameters()), retain_graph=True, create_graph=True
        )
        if self.use_second_order_stationarity:
            return sum(torch.sum(g ** 2) for g in grads)
        return sum(torch.sum(g.abs()) for g in grads)

    def lagrange_loss(self, y_pred, second_pred, second_l, alpha, guider):
        """Squared mismatch between scaled parameters and a weighted margin gradient.

        Differentiates ``mean((y_pred - second_pred) * second_l)`` with respect
        to the ``guider`` parameters named at construction time, and sums
        ``(L * p / num_samples - g) ** 2`` over all parameters, where ``L``
        comes from ``self.Lambda.Lambda_prime_dict(alpha)``, ``p`` is the
        detached parameter and ``g`` its gradient.
        """
        rhs = ((y_pred - second_pred) * second_l).mean()
        parameter_dict = dict(guider.named_parameters())
        parameters = [parameter_dict[name] for name in self.parameter_name_list]
        grad = torch.autograd.grad(
            outputs=rhs,
            inputs=parameters,
            create_graph=True,
            retain_graph=True,
        )

        Lambda_prime_dict = self.Lambda.Lambda_prime_dict(alpha)

        loss = 0
        for name, p, g in zip(self.parameter_name_list, parameters, grad):
            assert p.shape == g.shape
            # Parameters are treated as constants here; only ``g`` carries a graph.
            loss += (Lambda_prime_dict[name] * p.detach() / self.num_samples - g).pow(2).sum()
        return loss

def total_variation_loss(img, power):
    """Return the squared total variation of a 4-D tensor, averaged per element.

    Sums the squared differences between neighbouring entries along
    dimensions 2 and 3 and divides by the total number of elements of ``img``.
    For any input that is not 4-D a zero scalar on ``img``'s device is
    returned. ``power`` is accepted but not used: the exponent is always 2.
    """
    if img.dim() != 4:
        return torch.tensor(0.0).to(device=img.device)

    batch, channels, height, width = img.shape
    diff_dim2 = img[:, :, 1:, :] - img[:, :, :-1, :]
    diff_dim3 = img[:, :, :, 1:] - img[:, :, :, :-1]
    total = (diff_dim2 ** 2).sum() + (diff_dim3 ** 2).sum()
    return total / (batch * channels * height * width)

def Multiply_total_variation_loss(img, power):
    """Return a multiplicative total-variation penalty of a 4-D tensor.

    For every position that has a successor along both dimension 2 and
    dimension 3, the absolute differences to those two successors are
    multiplied, raised to ``power`` and summed; the sum is divided by the
    total number of elements of ``img``. For any input that is not 4-D a
    zero scalar on ``img``'s device is returned.
    """
    if img.dim() != 4:
        return torch.tensor(0.0).to(device=img.device)

    batch, channels, height, width = img.shape
    anchor = img[:, :, :-1, :-1]
    step_dim2 = torch.abs(anchor - img[:, :, 1:, :-1])
    step_dim3 = torch.abs(anchor - img[:, :, :-1, 1:])
    penalty = torch.pow(step_dim3 * step_dim2, power).sum()
    return penalty / (batch * channels * height * width)